#include "./graph_rt.h"

#include "./common.h"
#include "./helper.h"
#include "./ops.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/cg.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/profiler_plugin.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/serializer.h"

namespace py = pybind11;

using namespace mgb;
using namespace imperative;
namespace ser = mgb::serialization;

using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
using _SerializationMetadata = mgb::serialization::Metadata;
using _SerializationFormat = mgb::serialization::GraphDumpFormat;

namespace {
/**
 * Binds a GraphProfiler to the ComputingGraph it observes.
 *
 * The graph is co-owned through a shared_ptr for the lifetime of this object.
 * m_comp_graph is declared before m_profiler so it is initialized first and
 * its raw pointer can be handed to the profiler's constructor.
 */
class _CompGraphProfilerImpl {
    std::shared_ptr<ComputingGraph> m_comp_graph;
    GraphProfiler m_profiler;

public:
    //! explicit: a shared_ptr<ComputingGraph> must not silently convert into a
    //! profiler; the argument is a sink and is moved into the member.
    explicit _CompGraphProfilerImpl(std::shared_ptr<ComputingGraph> cg)
            : m_comp_graph{std::move(cg)}, m_profiler{m_comp_graph.get()} {}

    //! Serialize the full profiling result for the graph's current comp
    //! sequence into a JSON string.
    std::string _get_result() {
        auto json = m_profiler.to_json_full(m_comp_graph->current_comp_seq());
        return json->to_string();
    }
};

/**
 * Weak references to the rendezvous objects registered by output callbacks.
 * An instance is stored in the owning graph's options().user_data so that
 * still-alive rendezvous can be enumerated later (AsyncExecutable's
 * `_all_rendezvous` property), e.g. for exception handling.
 */
struct WeakRendezvousArray : public std::vector<std::weak_ptr<RendezvousBase>>,
                             public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;
};
MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray);
}  // namespace
//! Chained pybind11 helper: expands to `.def_readwrite("name", &CURRENT_CLASS::name)`.
//! CURRENT_CLASS must be #defined to the bound C++ type at each point of use.
#define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)

/**
 * Expose Rendezvous<T> to Python as class `name` inside scope `m`.
 *
 * The bound class is default-constructible (through Rendezvous<T>::make) and
 * provides set / get / drop / reset / set_exception. `get` is invoked with the
 * GIL released (it presumably waits for a value to be set). `set_exception`
 * stores a std::runtime_error carrying the given message.
 *
 * @return the pybind11 class object, so callers may attach more bindings.
 */
template <typename T>
auto def_rendezvous(py::object m, const char* name) {
    using RdvT = Rendezvous<T>;
    py::class_<RdvT, std::shared_ptr<RdvT>> cls(m, name);

    cls.def(py::init([]() { return RdvT::make(); }));
    cls.def("set", [](RdvT& self, T value) { self.set(std::move(value)); });
    cls.def("get", [](RdvT& self) { return self.get(); },
            py::call_guard<py::gil_scoped_release>());
    cls.def("drop", &RdvT::drop);
    cls.def("reset", &RdvT::reset);
    cls.def("set_exception", [](RdvT& self, std::string&& message) {
        self.set_exception(
                std::make_exception_ptr(std::runtime_error(std::move(message))));
    });

    return cls;
}

//! Tensor attributes (layout + comp node) delivered through TensorAttrRendezvous.
using TensorAttr = LogicalTensorDesc;
//! A host tensor paired with an event recorded right after its copy from the
//! device was issued (see value_output_callback); waiting on the event
//! presumably guarantees the host data is ready — confirm against CompNode docs.
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;

/**
 * Rewrite the dependency graph of `vars`, substituting repl_src[i] with
 * repl_dst[i], by delegating to mgb::cg::replace_vars.
 *
 * @param repl_src vars to be replaced
 * @param repl_dst replacement for each element of repl_src (same length)
 * @param vars endpoint vars whose dependencies should be rewritten
 * @return the VarNodes of the endpoints returned by cg::replace_vars
 *         (expected to correspond one-to-one with `vars`)
 * @throws mgb AssertionError if repl_src and repl_dst differ in length
 */
std::vector<mgb::cg::VarNode*> _replace_vars(
        const std::vector<mgb::cg::VarNode*>& repl_src,
        const std::vector<mgb::cg::VarNode*>& repl_dst,
        const std::vector<mgb::cg::VarNode*>& vars) {
    // repl_dst is indexed with repl_src's indices below; a shorter repl_dst
    // coming from Python would otherwise be read out of bounds.
    mgb_assert(
            repl_src.size() == repl_dst.size(),
            "_replace_vars: got %zu source vars but %zu replacement vars",
            repl_src.size(), repl_dst.size());
    mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
    for (size_t i = 0; i < repl_src.size(); ++i) {
        varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
    }
    SymbolVarArray symvars(vars.begin(), vars.end());
    auto sym_result = mgb::cg::replace_vars(symvars, varmap);
    std::vector<mgb::cg::VarNode*> result;
    result.reserve(sym_result.size());
    for (auto&& symvar : sym_result) {
        result.push_back(symvar.node());
    }
    return result;
}

typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;
std::vector<mgb::cg::VarNode*> _replace_oprs(
        const OperatorArray& repl_src, const OperatorArray& repl_dst,
        const std::vector<mgb::cg::VarNode*>& vars) {
    mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*> oprmap;
    for (size_t i = 0; i < repl_src.size(); ++i) {
        oprmap[repl_src[i]] = repl_dst[i];
    }
    const SymbolVarArray symvars(vars.begin(), vars.end());
    auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
    std::vector<mgb::cg::VarNode*> result;
    for (auto symvar : sym_result) {
        result.push_back(symvar.node());
    }
    return result;
}

/**
 * Visit every operator the given vars depend on (via cg::DepOprIter). Each
 * operator whose priority is still 0 gets its priority set to its own id.
 * Operators that already carry a non-zero priority are left unchanged.
 */
void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
    mgb::cg::DepOprIter walker{[](mgb::cg::OperatorNodeBase* node) {
        auto& prio = node->node_prop().attribute().priority;
        if (prio != 0) {
            return;  // keep explicitly assigned priorities
        }
        prio = node->id();
    }};
    for (auto* endpoint : dest_vars) {
        walker.add(SymbolVar(endpoint));
    }
}

//! Python class object for VarNode; holds py::none() until init_graph_rt()
//! registers the binding and assigns it.
//! NOTE(review): a namespace-scope py::object is destroyed during static
//! destruction, possibly after the interpreter has finalized — confirm this is
//! safe for this module (or intentionally leak it).
py::object Py_Varnode = py::none();

/**
 * Register the graph-runtime bindings on module `m`. These include:
 * - the rendezvous classes;
 * - VarNode / OperatorNode / AsyncExecutable / ComputingGraph and its options;
 * - the profilers;
 * - the inference-optimization helpers;
 * - graph (de)serialization;
 * - the var-construction and callback helpers used by the tracing frontend.
 *
 * Also assigns the global Py_Varnode.
 */
void init_graph_rt(py::module m) {
    static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
            std::make_unique<mgb::OprFootprint>()};

    def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");

    def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");

    def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");

    Py_Varnode =
            py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
                    .def_property_readonly(
                            "owner", [](cg::VarNode* v) { return v->owner_opr(); })
                    .def_property_readonly(
                            "graph", [](cg::VarNode* v) { return v->owner_graph(); })
                    .def_property(
                            "name", py::overload_cast<>(&VarNode::name, py::const_),
                            py::overload_cast<std::string>(&VarNode::name))
                    .def_property_readonly(
                            "dtype", [](cg::VarNode* v) { return v->dtype(); })
                    .def_property_readonly(
                            "comp_node", [](cg::VarNode* v) { return v->comp_node(); })
                    .def_property_readonly(
                            "shape",
                            [](cg::VarNode* v) -> const TensorShape* {
                                auto&& mgr = v->owner_graph()->static_infer_manager();
                                return mgr.infer_shape_fallible(v);
                            })
                    .def_property_readonly(
                            "value",
                            [](cg::VarNode* v) -> py::object {
                                auto&& mgr = v->owner_graph()->static_infer_manager();
                                auto&& type = mgr.get_infer_type(v);
                                using InferType = cg::static_infer::InferType;
                                if (!(type.value &
                                      (InferType::CONST | InferType::RT_STATIC))) {
                                    return py::none();
                                }
                                auto* val = mgr.infer_value_fallible(v);
                                if (!val) {
                                    return py::none();
                                }
                                return py::cast(*val).attr("numpy")();
                            })
                    .def_property_readonly(
                            "id", [](cg::VarNode* v) { return (v->id()); })
                    .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); });

    py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(
            m, "OperatorNode")
            .def_property_readonly(
                    "graph",
                    [](cg::OperatorNodeBase* opr) { return opr->owner_graph(); })
            .def_property(
                    "name",
                    py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
                    py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
            .def_property_readonly(
                    "inputs",
                    [](cg::OperatorNodeBase* opr) { return to_tuple(opr->input()); })
            .def_property_readonly(
                    "outputs",
                    [](cg::OperatorNodeBase* opr) {
                        return to_tuple(opr->usable_output());
                    })
            .def_property_readonly(
                    "id", [](cg::OperatorNodeBase* opr) { return opr->id(); })
            .def_property_readonly(
                    "params",
                    [](cg::OperatorNodeBase* opr) {
                        return _imperative_sm_opr_footprint_ptr->calc_footprint(opr)
                                .param->to_string();
                    })
            .def_property_readonly(
                    "type",
                    [](cg::OperatorNodeBase* opr) { return opr->dyn_typeinfo()->name; })
            .def("__repr__",
                 [](cg::OperatorNodeBase* opr) { return "Opr:" + opr->name(); })
            .def_property(
                    "priority",
                    [](cg::OperatorNodeBase* opr) {
                        return opr->node_prop().attribute().priority;
                    },
                    [](cg::OperatorNodeBase* opr, int priority) {
                        opr->node_prop().attribute().priority = priority;
                    });

    py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
            .def("execute", &cg::AsyncExecutable::execute,
                 py::call_guard<py::gil_scoped_release>())
            .def("wait", &cg::AsyncExecutable::wait,
                 py::call_guard<py::gil_scoped_release>())
            .def("get_prev_exec_time", &cg::AsyncExecutable::get_prev_exec_time,
                 py::call_guard<py::gil_scoped_release>())
            .def("_to_json",
                 [](cg::AsyncExecutable* exec) {
                     // dump currently compiled computing graph for debugging
                     return exec->to_json()->to_string();
                 },
                 // A call_guard only takes effect as a .def() extra; constructing
                 // one inside the lambda body is a no-op and would not release
                 // the GIL.
                 py::call_guard<py::gil_scoped_release>())
            // only used for exception handle
            .def_property_readonly(
                    "_all_rendezvous",
                    [](cg::AsyncExecutable* exec) {
                        auto ud =
                                exec->owner_graph()
                                        ->options()
                                        .user_data.get_user_data<WeakRendezvousArray>();
                        std::vector<std::shared_ptr<RendezvousBase>> ret;
                        if (ud.second) {
                            for (auto&& r : *ud.first[0]) {
                                if (auto p = r.lock()) {
                                    ret.emplace_back(std::move(p));
                                }
                            }
                        }
                        return ret;
                    })
            .def("get_static_memory_alloc_info",
                 &cg::AsyncExecutable::get_static_memory_alloc_info,
                 py::call_guard<py::gil_scoped_release>());

    auto PyComputingGraph =
            py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(
                    m, "ComputingGraph")
                    .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
                    .def("compile",
                         [](cg::ComputingGraph& graph,
                            const std::vector<cg::VarNode*>& dest_vars) {
                             mgb_assert(!dest_vars.empty());
                             cg::ComputingGraph::OutputSpec spec;
                             for (auto v : dest_vars) {
                                 spec.emplace_back(v, nullptr);
                             }
                             return graph.compile(spec);
                         })
                    .def("enable_weight_preprocess",
                         [](cg::ComputingGraph& graph) {
                             graph.options().graph_opt.enable_weight_preprocess();
                         })
                    .def_property_readonly(
                            "options",
                            py::overload_cast<>(&cg::ComputingGraph::options));

    py::class_<_CompGraphProfilerImpl, std::shared_ptr<_CompGraphProfilerImpl>>(
            m, "GraphProfiler")
            .def(py::init([](std::shared_ptr<ComputingGraph> graph) {
                return std::make_shared<_CompGraphProfilerImpl>(graph);
            }))
            .def("get", [](_CompGraphProfilerImpl& profiler) {
                return profiler._get_result();
            });

    using interpreter::intl::ProfilerPlugin;
    py::class_<ProfilerPlugin, std::shared_ptr<ProfilerPlugin>>(m, "GraphProfiler2")
            .def(py::init<cg::ComputingGraph*>());

    auto GraphOptimizeOptions =
            py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
                    .def(py::init())
                    .def("serialize", &_OptimizeForInferenceOptions::serialize)
                    .def_static(
                            "deserialize", &_OptimizeForInferenceOptions::deserialize)
                    .def_readwrite(
                            "f16_io_f32_comp",
                            &_OptimizeForInferenceOptions::f16_io_f32_comp)
                    .def_readwrite(
                            "f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
                    .def_readwrite(
                            "fuse_conv_bias_nonlinearity",
                            &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
                    .def_readwrite(
                            "fuse_conv_bias_with_z",
                            &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
                    .def_readwrite(
                            "fuse_preprocess",
                            &_OptimizeForInferenceOptions::fuse_preprocess)
                    .def_readwrite(
                            "layout_transform",
                            &_OptimizeForInferenceOptions::layout_transform)
                    .def_readwrite(
                            "fuse_grain", &_OptimizeForInferenceOptions::fuse_grain);

    py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
            .value("DEFAULT", _LayoutTransform::DEFAULT)
            .value("NCHW4", _LayoutTransform::NCHW4)
            .value("NHWCD4", _LayoutTransform::NHWCD4)
            .value("NCHW88", _LayoutTransform::NCHW88)
            .value("NCHW44", _LayoutTransform::NCHW44)
            .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
            .value("NCHW32", _LayoutTransform::NCHW32)
            .value("CHWN4", _LayoutTransform::CHWN4)
            .value("NCHW64", _LayoutTransform::NCHW64)
            .export_values();

    py::enum_<_SerializationFormat>(m, "SerializationFormat")
            .value("FBS", _SerializationFormat::FLATBUFFERS)
            .value("FBS_V2", _SerializationFormat::FLATBUFFERS_V2)
            .export_values();

    m.def("optimize_for_inference",
          [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
              SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
              auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
              VarNodeArray vars;
              for (auto& si : res_symvars)
                  vars.push_back(si.node());
              return vars;
          });

    m.def("modify_opr_algo_strategy_inplace",
          [](const VarNodeArray& dest_vars, const _AlgoStrategy& strategy) {
              mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, strategy);
          });

    m.def("get_info_for_strip", [](const std::vector<VarNode*>& dest_vars) {
        std::unordered_set<const char*> opr_types, dtype_names, elemwise_modes;
        auto on_opr = [&](cg::OperatorNodeBase* opr) {
            if (ser::GraphDumper::should_remove_in_dump(opr))
                return;
            opr_types.insert(opr->dyn_typeinfo()->name);
            for (auto i : opr->output())
                dtype_names.insert(i->dtype().name());
            if (opr->same_type<opr::Elemwise>()) {
                auto mode = opr->cast_final<opr::Elemwise>().param().mode;
                elemwise_modes.insert(
                        megdnn::Elemwise::ModeTrait::from_mode(mode).name);
            }
        };
        cg::DepOprIter opr_iter{on_opr};
        for (auto i : dest_vars)
            opr_iter.add(i->owner_opr());

        auto to_json = [](const std::unordered_set<const char*>& v) {
            std::vector<std::string> vs(v.begin(), v.end());
            std::sort(vs.begin(), vs.end());
            // The sets above dedupe by pointer identity; equal strings living at
            // different addresses must still appear only once in the output.
            vs.erase(std::unique(vs.begin(), vs.end()), vs.end());
            auto ret = json::Array::make();
            for (auto&& i : vs)
                ret->add(json::String::make(i));
            return ret;
        };

        return json::Object::make({
                                          {"opr_types", to_json(opr_types)},
                                          {"dtypes", to_json(dtype_names)},
                                          {"elemwise_modes", to_json(elemwise_modes)},
                                  })
                ->to_string();
    });

    py::class_<_SerializationMetadata>(m, "SerializationMetadata")
            .def(py::init())
            .def_property(
                    "user_info",
                    [](const _SerializationMetadata& meta) {
                        return py::bytes(meta.get_user_info());
                    },
                    &_SerializationMetadata::set_user_info)
            .def_readonly(
                    "optimized_for_inference",
                    &_SerializationMetadata::optimized_for_inference)
            .def_property(
                    "optimize_options", &_SerializationMetadata::get_optimize_options,
                    &_SerializationMetadata::set_optimize_options)
            .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
            .def_readwrite("is_valid", &_SerializationMetadata::is_valid);

    m.def("dump_graph",
          [](const std::vector<VarNode*>& dest_vars, int keep_var_name,
             bool keep_opr_name, bool keep_param_name, bool keep_opr_priority,
             bool no_change_graph, std::optional<_SerializationMetadata> metadata,
             std::optional<_SerializationFormat> dump_format,
             std::optional<int> model_version, py::list& stat, py::list& inputs,
             py::list& outputs, py::list& params) {
              std::vector<uint8_t> buf;
              ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2;
              int version = 2;
              if (dump_format.has_value()) {
                  format = dump_format.value();
              }
              if (model_version.has_value()) {
                  version = model_version.value();
              }
              auto dumper = ser::GraphDumper::make(
                      ser::OutputFile::make_vector_proxy(&buf), format, version);
              SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());

              ser::GraphDumper::DumpConfig config{
                      keep_var_name, keep_param_name, keep_opr_priority, keep_opr_name};
              config.no_change_graph = no_change_graph;

              ser::GraphDumper::DumpResult rst;
              if (metadata)
                  rst = dumper->dump(symvars, config, *metadata);
              else
                  rst = dumper->dump(symvars, config);

              for (auto i : rst.inputs) {
                  inputs.append(py::cast(i));
              }
              for (auto i : rst.outputs) {
                  outputs.append(py::cast(i));
              }
              for (auto i : rst.params) {
                  params.append(py::cast(i));
              }
              auto rst_stat = std::vector{
                      rst.nr_opr, rst.tot_bytes, rst.tensor_value_bytes,
                      static_cast<size_t>(rst.content_hash)};
              for (auto i : rst_stat) {
                  stat.append(py::cast(i));
              }
              // buf.data() stays valid for an empty buffer, unlike &buf[0].
              return py::bytes(reinterpret_cast<const char*>(buf.data()), buf.size());
          });

    m.def("load_graph",
          [](std::string& buf, py::list& output_var_map, py::list& output_var_list) {
              auto file = ser::InputFile::make_mem_proxy(buf.c_str(), buf.length());
              auto format = ser::GraphLoader::identify_graph_dump_format(*file);
              auto loader = ser::GraphLoader::make(std::move(file), format.val());
              ser::GraphLoader::LoadConfig config;
              auto rst = loader->load(config);
              for (const auto& i : rst.output_var_map) {
                  output_var_map.append(py::make_tuple(i.first, i.second.node()));
              }
              for (const auto& i : rst.output_var_list) {
                  output_var_list.append(i.node());
              }
              // Name each Host2DeviceCopy output after the tensor_map key of the
              // host tensor it reads.
              std::unordered_map<HostTensorND*, const std::string*> tensor2name;
              for (const auto& pair : rst.tensor_map) {
                  tensor2name[pair.second.get()] = &pair.first;
              }
              auto cb = [&tensor2name](cg::OperatorNodeBase* opr) {
                  if (!opr->same_type<opr::Host2DeviceCopy>())
                      return;
                  auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
                  auto it = tensor2name.find(h2d.host_data().get());
                  mgb_throw_if(
                          it == tensor2name.end(), GraphError,
                          "unbound Host2DeviceCopy in loaded graph");
                  h2d.output(0)->name(*it->second);
              };
              cg::DepOprIter iter{cb};
              for (const auto& var : rst.output_var_list) {
                  iter.add(var);
              }
              auto ret = py::tuple(2);
              ret[0] = py::cast(rst.graph);
              ret[1] = py::cast(rst.metadata);
              return ret;
          });

#define CURRENT_CLASS cg::ComputingGraph::Options

    // clang-format off
    auto PyComputingGraphOptions =
            py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
                // DEF_READWRITE(opr_attribute)
                DEF_READWRITE(seq_opt)
                DEF_READWRITE(graph_opt)
                DEF_READWRITE(graph_opt_level)
                DEF_READWRITE(log_level)
                DEF_READWRITE(async_exec_level)
                DEF_READWRITE(force_dynamic_alloc)
                DEF_READWRITE(var_sanity_check_first_run)
                DEF_READWRITE(allocate_static_mem_after_graph_compile)
                DEF_READWRITE(fake_next_exec)
                DEF_READWRITE(enable_sublinear_memory_opt)
                DEF_READWRITE(enable_dtr_memory_opt)
                DEF_READWRITE(no_profiling_on_shape_change)
                DEF_READWRITE(enable_var_mem_defragment)
                DEF_READWRITE(enable_grad_var_static_reshape)
                DEF_READWRITE(enable_memory_swap)
                DEF_READWRITE(comp_node_seq_record_level)
                DEF_READWRITE(no_force_inplace)
                DEF_READWRITE(sublinear_mem_config)
                DEF_READWRITE(dtr_config)
                // DEF_READWRITE(eager_evaluation)
                // DEF_READWRITE(imperative_proxy_graph)
                // DEF_READWRITE(extra_vardeps)
                // DEF_READWRITE(user_data)
            ;
    // clang-format on

#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::SeqOpt

    py::class_<cg::ComputingGraph::Options::SeqOpt>(PyComputingGraphOptions, "SeqOpt")
            DEF_READWRITE(enable_mem_plan_opt) DEF_READWRITE(enable_mem_reuse_alloc)
                    DEF_READWRITE(enable_seq_comp_node_opt);

#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt

    auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
            PyComputingGraphOptions, "GraphOpt") DEF_READWRITE(jit)
            DEF_READWRITE(jit_config)
            DEF_READWRITE(tensorrt);

#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig

    py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(
            PyGraphOpt, "JITConfig") DEF_READWRITE(fuse_dimshuffle)
            DEF_READWRITE(fuse_reduce);

#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig

    py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(
            PyComputingGraphOptions, "SublinearMemConfig") DEF_READWRITE(thresh_nr_try)
            DEF_READWRITE(genetic_nr_iter) DEF_READWRITE(genetic_pool_size)
                    DEF_READWRITE(lb_memory_mb) DEF_READWRITE(num_worker);

#undef CURRENT_CLASS

#define CURRENT_CLASS cg::ComputingGraph::Options::DTRConfig

    py::class_<cg::ComputingGraph::Options::DTRConfig>(
            PyComputingGraphOptions, "DTRConfig") DEF_READWRITE(eviction_threshold)
            DEF_READWRITE(evictee_minimum_size) DEF_READWRITE(recomp_memory_factor)
                    DEF_READWRITE(recomp_time_factor);

#undef CURRENT_CLASS
    auto common = rel_import("common", m, 1);

    common.def(
            "invoke_op",
            [](const OpDef& def, const std::vector<cg::VarNode*> inputs,
               cg::ComputingGraph* graph) {
                cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
                return to_tuple(OpDef::apply_on_var_node(def, vinputs));
            },
            py::arg(), py::arg(), py::arg("graph") = py::none());

    // Shared implementation of the two `input_callback` overloads: builds an
    // opr::InputCallback on `graph` (or, if null, on the graph owning inputs[0])
    // and returns its output VarNodes.
    auto input_callback = [](auto callback, const CompNode& comp_node,
                             const DType& dtype, const TensorShape& shape,
                             const std::vector<cg::VarNode*>& inputs,
                             cg::ComputingGraph* graph, bool use_static_shape) {
        if (!graph) {
            // `inputs` defaults to an empty tuple on the Python side, so without
            // this check inputs[0] would be read out of bounds.
            if (inputs.empty()) {
                throw py::value_error(
                        "input_callback: `graph` must be given when `inputs` is "
                        "empty");
            }
            graph = inputs[0]->owner_graph();
        }
        SymbolVarArray sinputs;
        for (auto i : inputs) {
            sinputs.emplace_back(i);
        }
        static_assert(!std::is_reference<decltype(callback)>::value);
        auto soutputs = opr::InputCallback::make(
                *graph, std::move(callback), comp_node, dtype, shape, sinputs,
                use_static_shape);
        std::vector<VarNode*> outputs;
        outputs.reserve(soutputs.size());
        for (auto i : soutputs) {
            outputs.push_back(i.node());
        }
        return outputs;
    };

    m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
        return opr::SharedDeviceTensor::make(
                       *graph, std::make_shared<DeviceTensorND>(data))
                .node();
    });

    m.def(
            "make_const",
            [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype,
               std::optional<std::string> name) {
                if (!cn.valid()) {
                    cn = CompNode::load(get_default_device());
                }
                OperatorNodeConfig config(cn);
                if (name) {
                    config.name(*name);
                }
                auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
                return opr::ImmutableTensor::make(*graph, hv, config).node();
            },
            py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());

    m.def(
            "make_h2d",
            [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape,
               std::optional<std::string> name) {
                if (!cn.valid()) {
                    throw py::type_error("device must be valid");
                }
                if (!dtype.valid()) {
                    throw py::type_error("dtype must be valid");
                }
                OperatorNodeConfig config;
                if (name) {
                    config.name(*name);
                }
                return opr::Host2DeviceCopy::make(
                               graph, std::make_shared<HostTensorND>(cn, shape, dtype),
                               config)
                        .node();
            },
            py::arg(), py::arg(), py::arg(), py::arg() = py::none(),
            py::arg() = py::none());

    m.def("_replace_vars", &_replace_vars, py::arg(), py::arg(), py::arg());
    m.def("_replace_oprs", &_replace_oprs, py::arg(), py::arg(), py::arg());
    m.def("_set_priority_to_id", &_set_priority_to_id, py::arg());

    m.def(
            "input_callback",
            [input_callback](
                    std::function<DeviceTensorND(void)> callback,
                    const CompNode& comp_node, const DType& dtype,
                    const TensorShape& shape, const std::vector<cg::VarNode*>& inputs,
                    cg::ComputingGraph* graph, bool use_static_shape) {
                return input_callback(
                        [f = std::move(callback)]() {
                            py::gil_scoped_acquire _;
                            return f();
                        },
                        comp_node, dtype, shape, inputs, graph, use_static_shape);
            },
            py::arg(), py::arg(), py::arg(), py::arg() = py::none(),
            py::arg() = py::tuple(), py::arg("graph") = py::none(),
            py::arg("use_static_shape") = false);

    m.def(
            "input_callback",
            [input_callback](
                    std::shared_ptr<Rendezvous<DeviceTensorND>> p,
                    const CompNode& comp_node, const DType& dtype,
                    const TensorShape& shape, const std::vector<cg::VarNode*>& inputs,
                    cg::ComputingGraph* graph, bool use_static_shape) {
                auto f = [p]() -> DeviceTensorND { return p->get(); };
                return input_callback(
                        std::move(f), comp_node, dtype, shape, inputs, graph,
                        use_static_shape);
            },
            py::arg(), py::arg(), py::arg(), py::arg() = py::none(),
            py::arg() = py::tuple(), py::arg("graph") = py::none(),
            py::arg("use_static_shape") = false);

    // Shared implementation of the output-callback bindings. When a rendezvous
    // `r` is given, it is also recorded (weakly) in the graph's user data so it
    // can be enumerated via AsyncExecutable._all_rendezvous.
    auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs,
                              std::shared_ptr<RendezvousBase> r = {},
                              bool borrow = false, bool prefer_host_value = false) {
        if (r) {
            mgb_assert(inputs.size());
            auto cg = inputs[0]->owner_graph();
            cg->options()
                    .user_data.get_user_data_or_create<WeakRendezvousArray>()
                    ->emplace_back(r);
        }
        SymbolVarArray sinputs;
        for (auto i : inputs) {
            sinputs.emplace_back(i);
        }
        static_assert(!std::is_reference<decltype(callback)>::value);
        opr::OutputCallback::Param param{
                std::move(callback), borrow, prefer_host_value};
        auto output = opr::OutputCallback::make(std::move(param), sinputs);
        return output.node();
    };

    m.def("output_callback", [output_callback](
                                     std::function<void(DeviceTensorND)> callback,
                                     std::vector<cg::VarNode*> inputs) {
        auto f = [f = std::move(callback)](DeviceTensorND dv) {
            auto task = [f = std::move(f), dv = std::move(dv)]() { f(dv); };
            py_task_q.add_task(std::move(task));
        };
        return output_callback(std::move(f), std::move(inputs));
    });

    m.def("output_callback", [output_callback](
                                     std::shared_ptr<Rendezvous<DeviceTensorND>> p,
                                     std::vector<cg::VarNode*> inputs) {
        auto f = [p](DeviceTensorND dv) { p->set(std::move(dv)); };
        return output_callback(std::move(f), std::move(inputs), p);
    });

    m.def("value_output_callback",
          [output_callback](
                  std::shared_ptr<Rendezvous<HostNDWithEvent>> p,
                  std::vector<cg::VarNode*> inputs) {
              auto f = [p](DeviceTensorND dv) {
                  HostNDWithEvent hv_with_event;
                  hv_with_event.first.copy_from(dv);
                  hv_with_event.second = dv.comp_node().create_event();
                  hv_with_event.second->record();
                  p->set(std::move(hv_with_event));
              };
              return output_callback(std::move(f), std::move(inputs), p, true, true);
          });

    m.def("attr_output_callback", [output_callback](
                                          std::shared_ptr<Rendezvous<TensorAttr>> p,
                                          std::vector<cg::VarNode*> inputs) {
        auto f = [p](DeviceTensorND dv) {
            p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
        };
        return output_callback(std::move(f), std::move(inputs), p, true);
    });

    m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) {
        // The owning graph is taken from inputs[0]; reject an empty list instead
        // of reading out of bounds.
        if (inputs.empty()) {
            throw py::value_error("virtual_dep: `inputs` must not be empty");
        }
        auto&& graph = inputs[0]->owner_graph();
        VarNodeArray inps(inputs.begin(), inputs.end());
        cg::OperatorNodeConfig config;
        if (device.length() > 0) {
            config.comp_node(CompNode::load(device));
        }
        cg::OperatorNodeBase* opr =
                graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>(inps, config));
        return opr;
    });
}
